package com.luo.d3s.ext.aop.exception.handler;

import com.luo.d3s.core.exception.BaseException;
import com.luo.d3s.core.exception.ValidateException;
import com.luo.d3s.core.util.validation.Validates;
import com.luo.d3s.ext.aop.config.D3sAopConsts;
import com.luo.d3s.ext.aop.config.D3sAopProps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.validation.BindException;
import org.springframework.validation.FieldError;
import org.springframework.web.bind.MethodArgumentNotValidException;

import javax.validation.ConstraintViolationException;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

/**
 * 默认异常处理器实现
 *
 * @author luohq
 * @date 2023-02-01 10:36
 */
public class D3sDefaultExceptionHandler implements D3sExceptionHandler {

    private static final Logger log = LoggerFactory.getLogger(D3sDefaultExceptionHandler.class);

    /**
     * Exception-handling configuration (validation error code, whether validation messages are exposed).
     */
    private final D3sAopProps.ExceptionProps exceptionProps;

    /**
     * Map(exception type, exception handler) - insertion-ordered collection: handlers are matched in
     * registration order, and the first registered type the thrown exception is an instance of wins.
     */
    protected final Map<Class<? extends Throwable>, BiFunction<Method, Throwable, Object>> exToHandlerMap = new LinkedHashMap<>(8);

    /**
     * Creates the handler, registering extension handlers first and the built-in defaults afterwards.
     * <p>
     * NOTE: {@link #prependPutExToHandler()} and {@link #initExToFuncMap()} are overridable hooks invoked
     * from this constructor, so fields declared in a subclass are not yet initialized when they run.
     *
     * @param exceptionProps exception-handling configuration
     */
    public D3sDefaultExceptionHandler(D3sAopProps.ExceptionProps exceptionProps) {
        this.exceptionProps = exceptionProps;
        //register client extension handlers first so they take precedence
        this.prependPutExToHandler();
        //register the built-in handlers (never replacing an extension handler for the same type)
        this.initExToFuncMap();
    }

    /**
     * Dispatches the exception to the first registered handler whose exception type matches.
     *
     * @param method    executed method
     * @param throwable thrown exception
     * @return the handler's response, or {@code null} when the method does not return a d3s response,
     * the throwable is {@code null}, or no handler matches
     */
    @Override
    public Object handleException(Method method, Throwable throwable) {
        if (!this.isReturnD3sResponse(method)) {
            log.error("Handle Exception without returning d3s response!", throwable);
            return null;
        }
        //iterate in registration order; isInstance is null-safe (null matches nothing)
        for (Map.Entry<Class<? extends Throwable>, BiFunction<Method, Throwable, Object>> entry : this.exToHandlerMap.entrySet()) {
            if (entry.getKey().isInstance(throwable)) {
                return entry.getValue().apply(method, throwable);
            }
        }
        return null;
    }

    /**
     * Hook for registering extension exception handlers ahead of the built-in ones
     * (empty by default; subclasses may override it).
     */
    protected void prependPutExToHandler() {
        //this.putExToHandler(MethodArgumentNotValidException.class, this::handleMethodArgumentNotValidException);
    }

    /**
     * Initializes Map(exception type, exception handler) with the built-in handlers.
     * <p>
     * Defaults are registered with {@link #putExToHandlerIfAbsent}, so an extension handler registered in
     * {@link #prependPutExToHandler()} for the same type is kept instead of being overwritten.
     * Order matters: more specific types must be registered before their supertypes.
     */
    protected void initExToFuncMap() {
        this.putExToHandlerIfAbsent(BaseException.class, this::handleBaseException);
        this.putExToHandlerIfAbsent(MethodArgumentNotValidException.class, this::handleMethodArgumentNotValidException);
        //after MethodArgumentNotValidException, which is a BindException subclass in newer Spring versions
        this.putExToHandlerIfAbsent(BindException.class, this::handleBindException);
        this.putExToHandlerIfAbsent(ConstraintViolationException.class, this::handleConstraintViolations);
        this.putExToHandlerIfAbsent(Throwable.class, this::handleThrowable);
    }

    /**
     * Registers an exception handler, replacing any handler already registered for the same type
     * (the type keeps its original position in the matching order).
     *
     * @param exClass              exception type Class
     * @param concreteEHandlerFunc exception handling function
     * @param <T>                  exception type
     */
    protected <T extends Throwable> void putExToHandler(Class<T> exClass, BiFunction<Method, T, Object> concreteEHandlerFunc) {
        this.exToHandlerMap.put(exClass, (method, throwable) -> concreteEHandlerFunc.apply(method, exClass.cast(throwable)));
    }

    /**
     * Registers an exception handler only if no handler is registered for exactly this type yet.
     *
     * @param exClass              exception type Class
     * @param concreteEHandlerFunc exception handling function
     * @param <T>                  exception type
     */
    protected <T extends Throwable> void putExToHandlerIfAbsent(Class<T> exClass, BiFunction<Method, T, Object> concreteEHandlerFunc) {
        if (!this.exToHandlerMap.containsKey(exClass)) {
            this.putExToHandler(exClass, concreteEHandlerFunc);
        }
    }


    /**
     * Handles a BaseException and returns the response.
     * The message of a ValidateException is hidden when validation messages are disabled.
     *
     * @param method executed method
     * @param be     base exception
     * @return response result
     */
    protected Object handleBaseException(Method method, BaseException be) {
        boolean isValidateExAndDisableMsg = (be instanceof ValidateException && !this.isValidationMessageEnabled());
        String errMsg = isValidateExAndDisableMsg ? null : be.getMessage();
        return this.wrapErrorResp(method, be.getErrCode(), errMsg);
    }


    /**
     * Handles a BindException and returns the response.<br/>
     * <b>Triggered by form parameters (object parameters without @RequestBody)</b>
     *
     * @param method executed method
     * @param be     bind exception
     * @return response result
     */
    protected Object handleBindException(Method method, BindException be) {
        return this.wrapErrorResp(method, this.exceptionProps.getDefaultValidationErrorCode(), this.convertFiledErrors(be.getFieldErrors()));
    }

    /**
     * Handles a MethodArgumentNotValidException and returns the response.<br/>
     * <b>Triggered when @Validated is placed on a @RequestBody parameter</b>
     *
     * @param method executed method
     * @param me     method argument validation exception
     * @return response result
     */
    protected Object handleMethodArgumentNotValidException(Method method, MethodArgumentNotValidException me) {
        return this.wrapErrorResp(method, this.exceptionProps.getDefaultValidationErrorCode(), this.convertFiledErrors(me.getBindingResult().getFieldErrors()));
    }


    /**
     * Converts a ConstraintViolationException into an error response.<br/>
     * <b>Triggered when @Validated is on the controller class and constraints are declared directly in the parameter list</b>
     *
     * @param method executed method
     * @param ce     validation exception
     * @return response result
     */
    protected Object handleConstraintViolations(Method method, ConstraintViolationException ce) {
        String validateErrMsg = Optional.ofNullable(ce.getConstraintViolations())
                .filter(constraintViolations -> this.isValidationMessageEnabled())
                .map(Validates::convertConstraintViolationsUnknownType)
                .orElse(null);
        return this.wrapErrorResp(method, this.exceptionProps.getDefaultValidationErrorCode(), validateErrMsg);
    }

    /**
     * Handles any other Throwable: logs it and returns a generic system error response.
     *
     * @param method    executed method
     * @param throwable exception
     * @return response result
     */
    protected Object handleThrowable(Method method, Throwable throwable) {
        log.error("Handle unknown exception", throwable);
        return this.wrapErrorResp(method, "SYS_EXCEPTION", "SYS_EXCEPTION");
    }

    /**
     * Converts a list of FieldError into an error message ("field:message" entries joined by a comma separator).
     *
     * @param fieldErrors field error list
     * @return error message, or {@code null} when the list is null or validation messages are disabled
     */
    protected String convertFiledErrors(List<FieldError> fieldErrors) {
        return Optional.ofNullable(fieldErrors)
                .filter(fieldErrorsInner -> this.isValidationMessageEnabled())
                .map(fieldErrorsInner -> fieldErrorsInner.stream()
                        .map(fieldError -> String.format("%s:%s", fieldError.getField(), fieldError.getDefaultMessage()))
                        .collect(Collectors.joining(D3sAopConsts.SEPARATOR_COMMA)))
                .orElse(null);
    }

    /**
     * Whether validation details may be exposed in error responses.
     * A missing (null) setting is treated as disabled instead of failing with an unboxing NPE.
     *
     * @return true only when the setting is explicitly enabled
     */
    private boolean isValidationMessageEnabled() {
        return Boolean.TRUE.equals(this.exceptionProps.getEnableValidationMessage());
    }
}
